# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re
import subprocess

import yaml
from op_gen import (
    PD_MANUAL_OP_LIST,
    OpCompatParser,
    OpInfoParser,
    to_pascal_case,
)

# white ops list whose kernel can automatically do type promotion.
# future will get this list from same place with dynamic graph.
# Maps op name -> the two input names that participate in type promotion.
type_promote_white_list = {
    "add": ["x", "y"],
    "subtract": ["x", "y"],
    "divide": ["x", "y"],
    "floor_divide": ["x", "y"],
    "trunc_divide": ["x", "y"],
    "elementwise_pow": ["x", "y"],
    "where": ["x", "y"],
    "equal": ["x", "y"],
    "not_equal": ["x", "y"],
    "less_than": ["x", "y"],
    "less_equal": ["x", "y"],
    "greater_than": ["x", "y"],
    "greater_equal": ["x", "y"],
    "logical_and": ["x", "y"],
    "logical_or": ["x", "y"],
    "logical_xor": ["x", "y"],
    "fmax": ["x", "y"],
    "fmin": ["x", "y"],
    "maximum": ["x", "y"],
    "minimum": ["x", "y"],
    "remainder": ["x", "y"],
    "huber_loss": ["input", "label"],
    "nextafter": ["x", "y"],
    "atan2": ["x", "y"],
    "multiply": ["x", "y"],
    "copysign": ["x", "y"],
    "cross": ["x", "y"],
}

# Inplace variants: the first listed input is promoted in place
# (pir::PromoteCastInplace), the second is cast to the promoted dtype.
type_promote_inplace_white_list = {
    "add_": ["x", "y"],
    "subtract_": ["x", "y"],
    "divide_": ["x", "y"],
    "floor_divide_": ["x", "y"],
    "trunc_divide_": ["x", "y"],
    "where_": ["x", "y"],
    "equal_": ["x", "y"],
    "not_equal_": ["x", "y"],
    "less_than_": ["x", "y"],
    "less_equal_": ["x", "y"],
    "greater_than_": ["x", "y"],
    "greater_equal_": ["x", "y"],
    "logical_and_": ["x", "y"],
    "logical_or_": ["x", "y"],
    "logical_xor_": ["x", "y"],
    "remainder_": ["x", "y"],
    "copysign_": ["x", "y"],
}

# ops support casting int tensor into float32 to do forward calculation
type_autocast_op_list = {
    "acos": ["x"],
    "acosh": ["x"],
    "asin": ["x"],
    "asinh": ["x"],
    "atan": ["x"],
    "atanh": ["x"],
    "cos": ["x"],
    "cosh": ["x"],
    "digamma": ["x"],
    "erf": ["x"],
    "erfinv": ["x"],
    "i0": ["x"],
    "i0e": ["x"],
    "i1": ["x"],
    "i1e": ["x"],
    "lgamma": ["x"],
    "logcumsumexp": ["x"],
    "logit": ["x"],
    "logsumexp": ["x"],
    "polygamma": ["x"],
    "reciprocal": ["x"],
    "rsqrt": ["x"],
    "sigmoid": ["x"],
    "sin": ["x"],
    "sinh": ["x"],
    "sqrt": ["x"],
    "stanh": ["x"],
    "tan": ["x"],
    "tanh": ["x"],
}

# ops support casting int tensor into float32 to do forward calculation,
# and it is valid to cast float32 gradient back to int tensor.
type_autocast_valid_grad_op_list = {}

# Ops skipped by this generator; their declarations are provided manually
# (manual_api.h is included by the generated header).
PD_MANUAL_API_LIST = {
    'embedding_grad',
    'assign',
}

# Skeleton of the generated pd_api.h header; {body} receives all declarations
# wrapped in their namespaces.
H_FILE_TEMPLATE = """

// This file is generated by "paddle/fluid/pir/dialect/op_generator/api_gen.py"
#pragma once

#include <vector>

#include "paddle/utils/optional.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_api.h"

{body}

"""

# Skeleton of the generated pd_api.cc source; {body} receives all definitions
# wrapped in their namespaces.
CPP_FILE_TEMPLATE = """

// This file is generated by "paddle/fluid/pir/dialect/op_generator/api_gen.py"
#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h"
#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/builder.h"
#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/amp_utils.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/type_defs.h"
#include "paddle/phi/common/type_promotion.h"
#include "paddle/fluid/pir/utils/type_promotion_utils.h"

{body}

"""


# Wraps {body} in a C++ namespace block.
NAMESPACE_TEMPLATE = """
namespace {namespace} {{
{body}
}} // namespace {namespace}
"""


# One function declaration for the header file.
API_DECLARE_TEMPLATE = """
{ret_type} {api_name}({args});
"""


# One function definition for the source file.
API_IMPL_TEMPLATE = """
{ret_type} {api_name}({args}){{
    {inner_code}
}}

"""

# Body of a generated API: AMP / type promotion / autocast fast paths first,
# then input marshalling, the Build<> call, and output unpacking.
API_INNER_CODE_TEMPLATE = """
    // AMP Logic
    {amp_logic}
    // Type Promotion Logic
    {type_promotion_logic}
    // Type Autocast Logic
    {type_autocast_logic}
    {check_data_type}
    {handle_optional_inputs}
    {in_combine}
    {compute_op}
    {set_null_type}
    {handle_optional_outputs}
    {out_split}
    {set_stop_gradient}
    {return_result}"""

# Marks all result values as stop-gradient when grad recording is off.
SET_STOP_GRADIENT_TEMPLATE = """
    if (!egr::Controller::Instance().HasGrad()) {{
        SetStopGradient({value_list});
    }}
"""

# AMP fast path: collect input values, determine the AMP dst dtype, cast the
# inputs, then re-dispatch to the same API with AMP disabled.
AMP_LOGIC_TEMPLATE = """
    if (egr::Controller::Instance().GetCurrentAmpAttrs()->GetAmpLevel() != paddle::imperative::AmpLevel::O0){{
        VLOG(5) << "Check and Prepare For AMP: {op_name}";
        auto op_name = phi::TransToFluidOpName("{op_name}");
        paddle::small_vector<std::vector<pir::Value>, egr::kSlotSmallVectorSize> amp_values_vector = {{ {no_optional_inputs} }};
        {optional_inputs}
        auto amp_dst_dtype = paddle::imperative::GetAmpDestDtype(op_name, amp_values_vector);
        {new_inputs}
        {{
            paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentAmpAttrs(), paddle::imperative::AmpLevel::O0);
            return paddle::dialect::{op_name}({args});
        }}
    }}
"""

AMP_OPTIONAL_INPUTS_TEMPLATE = """if ({optional_input}) {{ amp_values_vector.push_back({vec_optional_input}); }}
"""

AMP_NEW_INPUTS_TEMPLATE = """auto new_{input} = paddle::imperative::{cast_func}("{input}", {input}, amp_dst_dtype, op_name);
"""

# Type promotion fast path for binary ops: when operand dtypes differ, cast
# both to the promoted dtype ({x_cast} differs for inplace ops) and re-dispatch.
TYPE_PROMOTION_LOGIC_TEMPLATE = """
    auto op_name = phi::TransToFluidOpName("{op_name}");
    auto x_dtype = paddle::imperative::GetDataType({x});
    auto y_dtype = paddle::imperative::GetDataType({y});
    auto x_shape = pir::GetValueShape({x});
    auto y_shape = pir::GetValueShape({y});
    if (phi::NeedTypePromotion("{op_name}", x_dtype, y_dtype, x_shape, y_shape)) {{
    VLOG(5) << "got different data type, run type promotion automatically.";
    LOG_FIRST_N(WARNING, 1) << "got different data type, run type promotion automatically, this may cause data type been changed.";
    //{op_name}
    auto promotion_type = phi::GetPromoteDtype("{op_name}", x_dtype, y_dtype, x_shape, y_shape);

    {x_cast}
    auto new_{y} = pir::PromoteCast("{y}", {y}, promotion_type);

    return paddle::dialect::{op_name}({args});
  }}
"""

# Autocast fast path for math ops on integer tensors: cast to float32 and
# re-dispatch.
TYPE_AUTOCAST_LOGIC_TEMPLATE = """
    auto x_dtype = paddle::imperative::GetDataType({x});
    if (phi::NeedTypeAutoCast("{op_name}", x_dtype)) {{
    VLOG(5) << "math operation got integer input data type, run type autocast.";
    LOG_FIRST_N(WARNING, 1) << "math operation got integer input data type, run type autocast, this may cause data type been changed.";
    //{op_name}
    if (!{trace_backward}) {{ SetStopGradient({x}); }}
    auto new_{x} = pir::PromoteCast("{x}", {x}, phi::DataType::FLOAT32);
    return paddle::dialect::{op_name}({args});
  }}
"""

# Guarded dispatch branch for ops with multiple kernel variants.
OP_DISPATCH_TEMPLATE = """
    if ({cond}) {{
        {inner_code}
    }}"""

OP_DISPATCH_ERROR_TEMPLATE = """
    PADDLE_THROW(common::errors::Unimplemented(
        "The kernel of ({op_name}) for input Value is unimplemented, please check the type of input Value."));"""


# Calls a data-type-check helper on the listed inputs.
CHECK_DATA_TYPE_TEMPLATE = """
    {function}({inputs}, "{op_name}");"""

IF_TEMPLATE = """
    if ({condition}) {{
    {check_statement}
    }}"""

ELSE_IF_TEMPLATE = """
    else if ({condition}) {{
    {check_statement}
    }}"""

ELSE_TEMPLATE = """
    else {{
    {check_statement}
    }}"""


# Wraps an optional vector input: absent -> empty pir::Value, present ->
# combined into a single value via CombineOp.
OPTIONAL_VECTOR_VALUE_INPUT_TEMPLATE = """
    paddle::optional<pir::Value> optional_{name};
    if (!{name}) {{
        optional_{name} = paddle::make_optional<pir::Value>(pir::Value());
    }} else {{
        auto optional_{name}_combine_op = ApiBuilder::Instance().GetBuilder()->Build<pir::CombineOp>({name}.get());
        optional_{name} = paddle::make_optional<pir::Value>(optional_{name}_combine_op.out());
    }}"""

# Wraps an optional scalar input: absent -> empty pir::Value.
OPTIONAL_VALUE_INPUT_TEMPLATE = """
    paddle::optional<pir::Value> optional_{name};
    if (!{name}) {{
        optional_{name} = paddle::make_optional<pir::Value>(pir::Value());
    }} else {{
        optional_{name} = {name};
    }}"""

# Wraps an optional output result; left empty when the op produced no value.
OPTIONAL_VALUE_OUTPUT_TEMPLATE = """
    paddle::optional<pir::Value> optional_{name};
    if (!IsEmptyValue({op_name}_op.result({index}))) {{
        optional_{name} = paddle::make_optional<pir::Value>({op_name}_op.result({index}));
    }}"""

# Same as above for vector outputs; splits the result back into values.
OPTIONAL_VECTOR_VALUE_OUTPUT_TEMPLATE = """
    paddle::optional<std::vector<pir::Value>> optional_{name};
    if (!IsEmptyValue({op_name}_op.result({index}))) {{
        auto optional_{name}_slice_op = ApiBuilder::Instance().GetBuilder()->Build<pir::SplitOp>({op_name}_op.result({index}));
        optional_{name} = paddle::make_optional<std::vector<pir::Value>>(optional_{name}_slice_op.outputs());
    }}"""

# Clears the result type when the inplace-mapped input was not provided.
SET_NULL_TYPE_WITH_INPLACE_TEMPLATE = """
    if (!{input}) {{
        {op_name}_op.result({index}).set_type(pir::Type());
    }}"""

# Clears the result type when the inferred dense-tensor dims are unknown (-1).
SET_NULL_TYPE_TEMPLATE = """
    pir::Type {op_name}_op_result_{index}_type = {op_name}_op.result({index}).type();
    if ({op_name}_op_result_{index}_type.isa<paddle::dialect::DenseTensorType>() && {op_name}_op_result_{index}_type.dyn_cast<paddle::dialect::DenseTensorType>().dims().size() == -1) {{
        {op_name}_op.result({index}).set_type(pir::Type());
    }}"""

COMBINE_OP_TEMPLATE = """
    auto {op_name} = builtin_combine({in_name}).defining_op<pir::CombineOp>();"""

SPLIT_OP_TEMPLATE = """
    auto {op_name} = ApiBuilder::Instance().GetBuilder()->Build<pir::SplitOp>({in_name});"""

COMPUTE_OP_TEMPLATE = """
    paddle::dialect::{op_class_name} {op_inst_name} = ApiBuilder::Instance().GetBuilder()->Build<paddle::dialect::{op_class_name}>({args});"""

# C++ type-name constants used when rendering signatures and checking the
# YAML-declared input/attribute types.
OP_INPUT = 'pir::Value'
DENSE_TENSOR_TYPE = "paddle::dialect::DenseTensorType"
DATA_TYPE = "paddle::dialect::DataTypeAttribute"
VECTOR_TYPE = 'pir::VectorType'
INTARRAY_ATTRIBUTE = "paddle::dialect::IntArrayAttribute"

# YAML-declared dialect type -> C++ parameter/return type.
VALUE_TYPE_MAP = {
    'paddle::dialect::DenseTensorType': 'pir::Value',
    'paddle::dialect::SelectedRowsType': 'pir::Value',
    'pir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<pir::Value>',
}
# Same mapping for optional inputs/outputs (wrapped in paddle::optional).
OPTIONAL_VALUE_TYPE_MAP = {
    'paddle::dialect::DenseTensorType': 'paddle::optional<pir::Value>',
    'paddle::dialect::SelectedRowsType': 'paddle::optional<pir::Value>',
    'pir::VectorType<paddle::dialect::DenseTensorType>': 'paddle::optional<std::vector<pir::Value>>',
}


def get_op_class_name(op_name):
    """Return the C++ op class identifier for *op_name* (PascalCase + 'Op')."""
    return f"{to_pascal_case(op_name)}Op"


class CodeGen:
    def __init__(self) -> None:
        # Stateless generator: nothing is stored on the instance; all inputs
        # are passed to the individual _gen_* methods.
        pass

    def _parse_yaml(self, op_yaml_files, op_compat_yaml_file):
        """Parse op-definition YAML files into a list of OpInfoParser items.

        Each op is matched with its entry from the op_compat YAML; grad ops
        without a direct compat entry fall back to their forward op's entry.
        """
        op_compat_parser = OpCompatParser(op_compat_yaml_file)
        op_info_items = []
        for yaml_file in op_yaml_files:
            with open(yaml_file, "r") as f:
                ops = yaml.safe_load(f)

            for op in ops:
                op_compat_item = op_compat_parser.get_compat(op['name'])
                # Grad ops may only be listed under their forward op's name.
                if (
                    op_compat_item is None
                    and op['name'].endswith(('_grad', '_grad_'))
                    and 'forward' in op
                ):
                    op_compat_item = op_compat_parser.get_compat(
                        op['forward']['name']
                    )

                # Special case: "pow" keeps only its scalar sub-entry.
                if (
                    op_compat_item is not None
                    and op_compat_item['op'] == "pow"
                    and 'scalar' in op_compat_item
                ):
                    op_compat_item = op_compat_item.pop('scalar')
                # NOTE(review): if op_compat_item is still None here, the item
                # assignments below raise TypeError — presumably every
                # support_tensor op has a compat entry; confirm upstream.
                if 'support_tensor' in op.keys() and op['support_tensor']:
                    (
                        scalar_item,
                        int_array_item,
                    ) = op_compat_parser.parse_support_tensor(op)
                    op_compat_item['scalar'] = scalar_item
                    op_compat_item['int_array'] = int_array_item
                op_info_items.append(
                    OpInfoParser(op, op_compat_item, yaml_file)
                )
        return op_info_items

    def _need_skip(self, op_info, op_name):
        """Skip ops lacking an infer_meta func unless implemented manually."""
        has_infer_meta = op_info.infer_meta_func is not None
        return not has_infer_meta and op_name not in PD_MANUAL_OP_LIST

    def _is_optional_input(self, op_info, input_name):
        """True when *input_name* is declared optional in the op's inputs."""
        names = op_info.input_name_list
        if input_name not in names:
            return False
        return op_info.input_optional_list[names.index(input_name)] == 'true'

    def _is_optional_output(self, op_info, output_name):
        """True when *output_name* is a non-intermediate output marked optional.

        Raises ValueError if *output_name* is not an output of the op.
        """
        # Previous version bound output_optional_list/intermediate_list to
        # locals it then partly ignored, and spelled the boolean result with
        # an if/else returning True/False; return the expression directly.
        index = op_info.output_name_list.index(output_name)
        return (
            op_info.output_intermediate_list[index] == 'false'
            and op_info.output_optional_list[index] == 'true'
        )

    def _is_backward_op(self, op_info):
        """True when any phi name of the op is a grad (backward) op."""
        # Previous version used for/else with no break, so the else clause
        # always ran; any() states the intent directly.
        return any(
            name.endswith(('_grad', '_grad_')) for name in op_info.op_phi_name
        )

    def _is_optional_inplace_output(self, op_info, output_name):
        """True when *output_name* is inplace-mapped onto an optional input.

        Backward (grad) ops never qualify.
        """
        # Previous version re-implemented the grad-name scan inline; reuse
        # the sibling helper for consistency.
        if self._is_backward_op(op_info):
            return False
        inplace_map = op_info.inplace_map
        if inplace_map is None:
            return False
        if output_name in inplace_map:
            input_index = op_info.input_name_list.index(
                inplace_map[output_name]
            )
            if op_info.input_optional_list[input_index] == 'true':
                return True
        return False

    def _need_optional_output(self, op_info, name):
        """Whether the generated API must wrap output *name* in paddle::optional."""
        if self._is_optional_inplace_output(op_info, name):
            return True
        return self._is_backward_op(op_info) and self._is_optional_output(
            op_info, name
        )

    # =====================================
    # Gen declare functions
    # =====================================
    def _gen_api_inputs(self, op_info):
        """Render the C++ parameter list for the op's tensor inputs."""
        names = op_info.input_name_list
        types = op_info.input_type_list
        optionals = op_info.input_optional_list
        assert len(names) == len(types) == len(optionals)
        params = []
        for name, type, optional in zip(names, types, optionals):
            mapping = (
                OPTIONAL_VALUE_TYPE_MAP
                if optional == 'true'
                else VALUE_TYPE_MAP
            )
            params.append(f'const {mapping[type]}& {name}')
        return ', '.join(params)

    def _gen_api_attrs(
        self, op_info, with_default, is_mutable_attr, is_vector_mutable_attr
    ):
        """Render the C++ parameter list for the op's attributes.

        With is_mutable_attr, mutable attributes become pir::Value parameters
        (vector form for IntArray when is_vector_mutable_attr) and are listed
        first; the rest keep their C++ type and, with with_default, their
        default value.
        """
        names = op_info.attribute_name_list
        types = op_info.attribute_build_arg_type_list
        defaults = op_info.attribute_default_value_list
        mutable_names = op_info.mutable_attribute_name_list
        mutable_types = op_info.mutable_attribute_type_list
        assert len(names) == len(types) == len(defaults)
        value_params = []
        plain_params = []
        for name, type, default in zip(names, types, defaults):
            if is_mutable_attr and name in mutable_names:
                is_int_array = (
                    mutable_types[mutable_names.index(name)][0]
                    == INTARRAY_ATTRIBUTE
                )
                if is_int_array and is_vector_mutable_attr:
                    value_params.append(f'std::vector<{OP_INPUT}> {name}')
                else:
                    value_params.append(f'{OP_INPUT} {name}')
            elif with_default and default is not None:
                # float/double defaults are stored quoted in the YAML.
                if type in ['float', 'double']:
                    default = default.strip('"')
                plain_params.append(f'{type} {name} = {default}')
            else:
                plain_params.append(f'{type} {name}')
        return ', '.join(value_params + plain_params)

    def _gen_api_args(
        self,
        op_info,
        with_default_attr,
        is_mutable_attr,
        is_vector_mutable_attr,
    ):
        """Render the full C++ argument list: tensor inputs, then attributes.

        Joins only the non-empty pieces. The previous
        ``(inputs + ', ' + attrs).strip(', ')`` relied on strip() to remove a
        dangling separator, which also strips any leading/trailing ',' or ' '
        characters belonging to the argument text itself.
        """
        inputs = self._gen_api_inputs(op_info)
        attrs = self._gen_api_attrs(
            op_info, with_default_attr, is_mutable_attr, is_vector_mutable_attr
        )
        return ', '.join(part for part in (inputs, attrs) if part)

    def _gen_ret_type(self, op_info):
        """Render the C++ return type: std::tuple<...>, a single type, or void."""
        names = op_info.output_name_list
        types = op_info.output_type_list
        intermediates = op_info.output_intermediate_list
        assert len(names) == len(types) == len(intermediates)

        # Intermediate outputs are internal-only and never returned.
        visible = [
            (name, type)
            for name, type, intermediate in zip(names, types, intermediates)
            if intermediate != 'true'
        ]
        if not visible:
            return 'void'
        mapped = [
            OPTIONAL_VALUE_TYPE_MAP[type]
            if self._need_optional_output(op_info, name)
            else VALUE_TYPE_MAP[type]
            for name, type in visible
        ]
        if len(mapped) == 1:
            return mapped[0]
        return 'std::tuple<{}>'.format(', '.join(mapped))

    def _gen_one_declare(
        self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
    ):
        """Render one API declaration; sparse ops get a 'sp' suffixed name."""
        api_name = op_name
        if op_info.is_sparse_op:
            api_name += "sp_" if api_name[-1] == "_" else "_sp"
        args = self._gen_api_args(
            op_info, True, is_mutable_attr, is_vector_mutable_attr
        )
        return API_DECLARE_TEMPLATE.format(
            ret_type=self._gen_ret_type(op_info),
            api_name=api_name,
            args=args,
        )

    def _gen_h_file(self, op_info_items, namespaces, h_file_path):
        """Write the generated header with declarations for every op."""
        declarations = []
        for op_info in op_info_items:
            for op_name in op_info.op_phi_name:
                # NOTE: When infer_meta_func is None, the Build() function
                # generated in pd_op is wrong, so temporarily skip the
                # automatic generation of these APIs.
                if (
                    self._need_skip(op_info, op_name)
                    or op_name in PD_MANUAL_API_LIST
                ):
                    continue
                declarations.append(
                    self._gen_one_declare(op_info, op_name, False, False)
                )
                if not op_info.mutable_attribute_name_list:
                    continue
                # Overload taking mutable attributes as pir::Value inputs.
                declarations.append(
                    self._gen_one_declare(op_info, op_name, True, False)
                )
                mutable_kinds = {
                    type[0] for type in op_info.mutable_attribute_type_list
                }
                if INTARRAY_ATTRIBUTE in mutable_kinds:
                    # Extra overload taking IntArray attributes as vectors.
                    declarations.append(
                        self._gen_one_declare(op_info, op_name, True, True)
                    )
        body = ''.join(declarations)
        for namespace in reversed(namespaces):
            body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
        with open(h_file_path, 'w') as f:
            f.write(H_FILE_TEMPLATE.format(body=body))

    # =====================================
    # Gen impl functions
    # =====================================
    def _gen_handle_optional_inputs(self, op_info):
        """Wrap optional inputs in paddle::optional, combining vector inputs."""
        names = op_info.input_name_list
        optionals = op_info.input_optional_list
        types = op_info.input_type_list
        assert len(names) == len(optionals) == len(types)
        pieces = []
        for name, optional, type in zip(names, optionals, types):
            if optional != 'true':
                continue
            if VECTOR_TYPE in type:
                pieces.append(
                    OPTIONAL_VECTOR_VALUE_INPUT_TEMPLATE.format(name=name)
                )
            else:
                pieces.append(OPTIONAL_VALUE_INPUT_TEMPLATE.format(name=name))
        return ''.join(pieces)

    def _gen_handle_optional_inplace_outputs(self, op_info, op_name):
        """Wrap optional (inplace/backward) outputs into paddle::optional."""
        names = op_info.output_name_list
        types = op_info.output_type_list
        intermediates = op_info.output_intermediate_list
        pieces = []
        for index, (name, type, intermediate) in enumerate(
            zip(names, types, intermediates)
        ):
            if intermediate == 'true' or not self._need_optional_output(
                op_info, name
            ):
                continue
            template = (
                OPTIONAL_VECTOR_VALUE_OUTPUT_TEMPLATE
                if VECTOR_TYPE in type
                else OPTIONAL_VALUE_OUTPUT_TEMPLATE
            )
            pieces.append(
                template.format(name=name, op_name=op_name, index=index)
            )
        return ''.join(pieces)

    def _gen_set_null_type(self, op_info, op_name):
        """Reset result types to pir::Type() for outputs that may be absent."""
        inplace_map = op_info.inplace_map
        pieces = []
        for index, out_name in enumerate(op_info.output_name_list):
            if self._is_optional_inplace_output(op_info, out_name):
                pieces.append(
                    SET_NULL_TYPE_WITH_INPLACE_TEMPLATE.format(
                        input=inplace_map[out_name],
                        op_name=op_name,
                        index=index,
                    )
                )
            elif self._is_optional_output(op_info, out_name):
                pieces.append(
                    SET_NULL_TYPE_TEMPLATE.format(op_name=op_name, index=index)
                )
        return ''.join(pieces)

    def _gen_in_combine(self, op_info, is_mutable_attr, is_vector_mutable_attr):
        """Build CombineOps for vector inputs (and vector mutable attributes).

        Returns the generated code plus a per-argument list holding each
        combine-op variable name, or None where no combine is needed.
        """
        code_pieces = []
        combine_op_list = []

        def _maybe_combine(name, needs_combine):
            # Records a combine op (and its code) or a None placeholder.
            if needs_combine:
                combine_name = f'{name}_combine_op'
                code_pieces.append(
                    COMBINE_OP_TEMPLATE.format(
                        op_name=combine_name, in_name=name
                    )
                )
                combine_op_list.append(combine_name)
            else:
                combine_op_list.append(None)

        names = op_info.input_name_list
        types = op_info.input_type_list
        optionals = op_info.input_optional_list
        assert len(names) == len(types) == len(optionals)
        for name, type, optional in zip(names, types, optionals):
            _maybe_combine(name, optional == 'false' and VECTOR_TYPE in type)

        if is_mutable_attr:
            attr_names = op_info.mutable_attribute_name_list
            attr_types = op_info.mutable_attribute_type_list
            assert len(attr_names) == len(attr_types)
            for name, type in zip(attr_names, attr_types):
                _maybe_combine(
                    name,
                    type[0] == INTARRAY_ATTRIBUTE and is_vector_mutable_attr,
                )

        return ''.join(code_pieces), combine_op_list

    def _gen_compute_op_args(
        self, op_info, in_combine_op_list, is_mutable_attr
    ):
        """Render the argument list passed to the Build<> call of the op."""
        input_names = op_info.input_name_list
        mutable_attrs = op_info.mutable_attribute_name_list
        assert len(input_names) + len(mutable_attrs) == len(
            in_combine_op_list
        ) or len(input_names) == len(in_combine_op_list)
        arg_names = (
            input_names + mutable_attrs if is_mutable_attr else input_names
        )
        args = []
        for arg_name, combine_op in zip(arg_names, in_combine_op_list):
            if combine_op is not None:
                args.append(f'{combine_op}.out()')
            elif self._is_optional_input(op_info, arg_name):
                args.append(f'optional_{arg_name}.get()')
            else:
                args.append(arg_name)
        if is_mutable_attr:
            args += list(op_info.non_mutable_attribute_name_list)
        else:
            args += list(op_info.attribute_name_list)
        return ', '.join(args)

    def _gen_compute_op(
        self, op_info, op_name, in_combine_op_list, is_mutable_attr
    ):
        """Render the Build<> statement; returns (code, op instance name)."""
        class_suffix = 'SpOp' if op_info.is_sparse_op else 'Op'
        op_class_name = to_pascal_case(op_name) + class_suffix
        op_inst_name = f'{op_name}_op'
        code = COMPUTE_OP_TEMPLATE.format(
            op_class_name=op_class_name,
            op_inst_name=op_inst_name,
            args=self._gen_compute_op_args(
                op_info, in_combine_op_list, is_mutable_attr
            ),
        )
        return code, op_inst_name

    def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
        """Build SplitOps for vector outputs and collect return expressions."""
        names = op_info.output_name_list
        types = op_info.output_type_list
        intermediates = op_info.output_intermediate_list
        optionals = op_info.output_optional_list
        assert (
            len(names)
            == len(types)
            == len(intermediates)
            == len(optionals)
        )

        split_pieces = []
        ret_list = []
        for index, (name, type, intermediate) in enumerate(
            zip(names, types, intermediates)
        ):
            if intermediate == 'true':
                continue
            if self._need_optional_output(op_info, name):
                ret_list.append(f'optional_{name}')
            elif VECTOR_TYPE in type:
                split_name = f'{name}_split_op'
                split_pieces.append(
                    SPLIT_OP_TEMPLATE.format(
                        op_name=split_name,
                        in_name=f'{op_inst_name}.result({index})',
                    )
                )
                ret_list.append(f'{split_name}.outputs()')
            else:
                ret_list.append(f'{op_inst_name}.result({index})')
        return ''.join(split_pieces), ret_list

    def _gen_set_stop_gradient(self, ret_list):
        """Render the stop-gradient guard; empty when there is nothing to set."""
        if not ret_list:
            return ''
        return SET_STOP_GRADIENT_TEMPLATE.format(
            value_list=', '.join(ret_list)
        )

    def _gen_return_result(self, ret_list):
        """Render the C++ return statement for 0, 1, or many results."""
        if not ret_list:
            return 'return;'
        if len(ret_list) == 1:
            return f'return {ret_list[0]};'
        joined = ', '.join(ret_list)
        return f'return std::make_tuple({joined});'

    def _gen_amp_no_optional_inputs(self, op_info):
        """Collect non-optional inputs as amp_values_vector initializers."""
        entries = []
        for name, type in zip(
            op_info.input_name_list, op_info.input_type_list
        ):
            if self._is_optional_input(op_info, name):
                continue
            # Scalar values are wrapped in a brace-init list; vectors are not.
            entries.append(name if VECTOR_TYPE in type else '{' + name + '}')
        return ', '.join(entries)

    def _gen_amp_optional_inputs(self, op_info):
        """Render conditional pushes of optional inputs into amp_values_vector."""
        pieces = []
        for name, type in zip(
            op_info.input_name_list, op_info.input_type_list
        ):
            if not self._is_optional_input(op_info, name):
                continue
            if VECTOR_TYPE in type:
                vec_expr = '*' + name
            else:
                vec_expr = '{ *' + name + ' }'
            pieces.append(
                AMP_OPTIONAL_INPUTS_TEMPLATE.format(
                    optional_input=name, vec_optional_input=vec_expr
                )
            )
        return ''.join(pieces)

    def _gen_amp_new_inputs(self, op_info, op_name):
        """Render AmpAutoCast statements producing new_<input> values.

        op_name is unused here; it is kept for signature compatibility.
        """
        pieces = []
        for name, type in zip(
            op_info.input_name_list, op_info.input_type_list
        ):
            func = 'AmpAutoCasts' if VECTOR_TYPE in type else 'AmpAutoCast'
            pieces.append(
                AMP_NEW_INPUTS_TEMPLATE.format(input=name, cast_func=func)
            )
        return ''.join(pieces)

    def _gen_amp_args(self, op_info, is_mutable_attr):
        """Render the argument list used when re-dispatching under AMP.

        Inputs are replaced by their casted new_<name> counterparts;
        attributes are passed through unchanged.
        """
        if is_mutable_attr:
            attr_list = (
                op_info.mutable_attribute_name_list
                + op_info.non_mutable_attribute_name_list
            )
        else:
            attr_list = op_info.attribute_name_list
        # input_name_list was duplicated in both branches before; it is the
        # same either way, so it is read once here (also avoids shadowing the
        # builtin 'input' in the comprehension).
        args = [f'new_{name}' for name in op_info.input_name_list] + attr_list
        return ', '.join(args)

    def _gen_amp_logic(self, op_info, op_name, is_mutable_attr):
        """Render the AMP dispatch block, or a VLOG note when AMP is skipped."""
        skip_reason = None
        if not op_info.input_name_list:
            skip_reason = (
                f'VLOG(5) << " No AMP for {op_name} because it has no input. ";'
            )
        elif op_name.endswith(('_grad', '_grad_')):
            skip_reason = 'VLOG(5) << " No AMP for grad apis. ";'
        elif op_name.endswith('_') or op_name == 'cast':
            skip_reason = f'VLOG(5) << "No AMP for {op_name} because it is a inplace or cast api.";'
        if skip_reason is not None:
            return skip_reason
        if op_info.is_sparse_op:
            op_name += "sp_" if op_name[-1] == "_" else "_sp"
        return AMP_LOGIC_TEMPLATE.format(
            op_name=op_name,
            no_optional_inputs=self._gen_amp_no_optional_inputs(op_info),
            optional_inputs=self._gen_amp_optional_inputs(op_info),
            new_inputs=self._gen_amp_new_inputs(op_info, op_name),
            args=self._gen_amp_args(op_info, is_mutable_attr),
        )

    def _gen_type_promotion_args(self, op_info, op_name):
        """Render the argument list used when re-dispatching after promotion.

        Promoted inputs are replaced with new_<name>; for inplace ops the
        first (inplace-target) input keeps its original name because it is
        promoted in place. Removes the unused ``inplace_map`` local of the
        previous version.
        """
        promoted_inputs = []
        for name in op_info.input_name_list:
            if op_name in type_promote_white_list:
                if name in type_promote_white_list[op_name]:
                    promoted_inputs.append(f"new_{name}")
                else:
                    promoted_inputs.append(f"{name}")
            elif op_name in type_promote_inplace_white_list:
                if name == type_promote_inplace_white_list[op_name][0]:
                    promoted_inputs.append(f"{name}")
                elif name in type_promote_inplace_white_list[op_name]:
                    promoted_inputs.append(f"new_{name}")
                else:
                    promoted_inputs.append(f"{name}")

        attr_list = op_info.attribute_name_list
        return ', '.join(promoted_inputs + attr_list)

    def _gen_type_promotion_logic(self, op_info, op_name):
        """Render the type-promotion dispatch block for whitelisted ops.

        Out-of-place ops cast both operands to the promoted dtype; inplace
        ops promote x in place and cast only y. Ops in neither whitelist get
        a VLOG placeholder. Removes the unused ``input_list`` local of the
        previous version.
        """
        if op_name in type_promote_white_list:
            x = type_promote_white_list[op_name][0]
            y = type_promote_white_list[op_name][1]
            args = self._gen_type_promotion_args(op_info, op_name)
            x_cast = (
                f'auto new_{x} = pir::PromoteCast("{x}", {x}, promotion_type);'
            )
            if op_info.is_sparse_op:
                op_name += "sp_" if op_name[-1] == "_" else "_sp"
            return TYPE_PROMOTION_LOGIC_TEMPLATE.format(
                op_name=op_name, x=x, y=y, x_cast=x_cast, args=args
            )
        if op_name in type_promote_inplace_white_list:
            x = type_promote_inplace_white_list[op_name][0]
            y = type_promote_inplace_white_list[op_name][1]
            args = self._gen_type_promotion_args(op_info, op_name)
            # Inplace ops promote x in place, so no new_x value is created.
            x_cast = f'pir::PromoteCastInplace("{x}", {x}, promotion_type);'
            return TYPE_PROMOTION_LOGIC_TEMPLATE.format(
                op_name=op_name, x=x, y=y, x_cast=x_cast, args=args
            )
        return f'\n VLOG(5) << " No Type Promotion for {op_name} api. "; '

    def _gen_type_autocast_args(self, op_info, op_name):
        """Build the C++ call-argument list used after type autocast.

        Inputs listed for *op_name* in ``type_autocast_op_list`` are passed
        as their autocast result (``new_<name>``); other inputs keep their
        original names. Attributes are always appended unchanged.
        """
        call_args = []
        if op_name in type_autocast_op_list:
            autocast_names = type_autocast_op_list[op_name]
            for input_name in op_info.input_name_list:
                if input_name in autocast_names:
                    call_args.append(f"new_{input_name}")
                else:
                    call_args.append(input_name)

        return ', '.join(call_args + op_info.attribute_name_list)

    def _gen_type_autocast_logic(self, op_info, op_name):
        """Generate the C++ type-autocast snippet for *op_name*.

        Ops present in ``type_autocast_op_list`` get the autocast template
        instantiated for their first listed input; any other op gets a
        VLOG line stating no autocast applies.
        """
        if op_name not in type_autocast_op_list:
            return f'\n VLOG(5) << " No Type Autocast for {op_name} api. "; '

        first_input = type_autocast_op_list[op_name][0]
        call_args_str = self._gen_type_autocast_args(op_info, op_name)
        # Render the Python bool as a C++ bool literal ("true"/"false").
        trace_backward = str(
            op_name in type_autocast_valid_grad_op_list
        ).lower()

        # Sparse ops use a suffixed name, matching the convention applied
        # elsewhere in this generator.
        if op_info.is_sparse_op:
            op_name += "sp_" if op_name[-1] == "_" else "_sp"
        return TYPE_AUTOCAST_LOGIC_TEMPLATE.format(
            op_name=op_name,
            x=first_input,
            trace_backward=trace_backward,
            args=call_args_str,
        )

    def _gen_check_data_type(self, op_info, op_name):
        """Generate the C++ data-type check statement(s) for one kernel.

        The check emitted depends on the kernel's declared ``data_type``
        candidates:

        * no candidates: fall back to checking the op's inputs, scanning
          them in reverse declaration order;
        * one candidate: check that single input/attribute;
        * two candidates: emit a combined dtype-or-value check.

        Returns an empty string when no check can or should be generated.
        """
        mapping_input_name_to_type = dict(
            zip(op_info.input_name_list, op_info.input_type_list)
        )
        mapping_attr_name_to_type = dict(
            zip(op_info.attribute_name_list, op_info.attribute_type_list)
        )

        # Candidates may name either an input or an attribute, so merge
        # both name->type mappings for the lookups below.
        mapping_name_to_type = {
            **mapping_input_name_to_type,
            **mapping_attr_name_to_type,
        }

        mapping_input_name_to_optional = dict(
            zip(op_info.input_name_list, op_info.input_optional_list)
        )

        # real_grad/imag_grad are explicitly excluded from data-type
        # checking, as are ops with no inputs or attributes at all.
        if (
            op_name in ["real_grad", "imag_grad"]
            or len(mapping_name_to_type) == 0
        ):
            return ""
        # kernel_map may be absent (TypeError) or lack the keys (KeyError);
        # either way, treat it as "no declared candidates".
        try:
            data_type_candidates = op_info.kernel_map['data_type']['candidates']
        except (KeyError, TypeError):
            data_type_candidates = None

        # Which C++ check helper to call for each supported argument type;
        # types not listed here are skipped.
        mapping_type_to_function_name = {
            f"{VECTOR_TYPE}<{DENSE_TENSOR_TYPE}>": 'CheckVectorOfValueDataType',
            DENSE_TENSOR_TYPE: 'CheckValueDataType',
            DATA_TYPE: 'CheckDataType',
        }

        if data_type_candidates is None or len(data_type_candidates) == 0:
            if len(op_info.input_name_list) == 0:
                return ""
            ret = ""
            # Scan inputs in reverse declaration order: optional inputs get
            # runtime-guarded checks chained with else-if, and the first
            # (i.e. last-declared) required input terminates the chain with
            # an unconditional check.
            for name in op_info.input_name_list[::-1]:
                type = mapping_input_name_to_type[name]
                optional = mapping_input_name_to_optional[name]
                # Skip inputs whose type has no corresponding check helper.
                if (
                    function_name := mapping_type_to_function_name.get(
                        type, None
                    )
                ) is None:
                    continue

                if optional == 'false':
                    if ret == "":
                        # No optional inputs seen yet: a single plain check.
                        return CHECK_DATA_TYPE_TEMPLATE.format(
                            function=function_name,
                            inputs=f'{name}, "{name}"',
                            op_name=op_name,
                        )
                    else:
                        # Required input closes the if/else-if chain built
                        # from the optional inputs after it.
                        ret += ELSE_TEMPLATE.format(
                            check_statement=CHECK_DATA_TYPE_TEMPLATE.format(
                                function=function_name,
                                inputs=f'{name}, "{name}"',
                                op_name=op_name,
                            ).strip("\n")
                        )
                        return ret
                else:
                    # Optional inputs are checked only when present; the
                    # first one opens the if-chain, later ones extend it.
                    if ret == "":
                        template = IF_TEMPLATE
                    else:
                        template = ELSE_IF_TEMPLATE

                    ret += template.format(
                        condition=name,
                        check_statement=CHECK_DATA_TYPE_TEMPLATE.format(
                            function=function_name,
                            inputs=f'{name}.get(), "{name}"',
                            op_name=op_name,
                        ).strip("\n"),
                    )
            return ret

        elif len(data_type_candidates) == 1:
            # A single declared candidate: check it directly if its type
            # has a corresponding helper.
            name = data_type_candidates[0]
            if name not in mapping_name_to_type:
                return ""
            type = mapping_name_to_type[name]
            if (
                function_name := mapping_type_to_function_name.get(type, None)
            ) is None:
                return ""
            return CHECK_DATA_TYPE_TEMPLATE.format(
                function=function_name,
                inputs=f'{name}, "{name}"',
                op_name=op_name,
            )
        elif len(data_type_candidates) == 2:
            # Two candidates: expected shape is (dtype attribute, tensor
            # value); anything else produces no check.
            dtype_name = data_type_candidates[0]
            value_name = data_type_candidates[1]
            dtype_type = mapping_name_to_type.get(dtype_name, None)
            value_type = mapping_name_to_type.get(value_name, None)
            if DENSE_TENSOR_TYPE != value_type or DATA_TYPE != dtype_type:
                return ""
            function_name = 'CheckDataTypeOrValue'
            return CHECK_DATA_TYPE_TEMPLATE.format(
                function=function_name,
                inputs=f'{dtype_name}, "{dtype_name}", {value_name}, "{value_name}"',
                op_name=op_name,
            )
        # Three or more candidates are not supported: emit no check.
        return ""

    def _gen_one_impl(
        self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
    ):
        """Generate the complete C++ implementation for one API function.

        When the kernel map declares more than one dispatch kernel, an
        if/else chain is generated that selects the kernel from the runtime
        types of the inputs; otherwise a single straight-line body is
        generated.

        Args:
            op_info: Parsed op description (inputs, attributes, kernel map).
            op_name: The API name being generated.
            is_mutable_attr: Generate the mutable-attribute overload.
            is_vector_mutable_attr: Generate the overload taking IntArray
                attributes as vectors.

        Returns:
            The generated C++ source for this API as a string.
        """
        # Dispatch tag (from the kernel yaml) -> pir type used in the
        # generated isa<> runtime dispatch condition. Tags not listed here
        # contribute no condition.
        dispatch_tag_to_pir_type = {
            'dense': 'paddle::dialect::DenseTensorType',
            'selected_rows': 'paddle::dialect::SelectedRowsType',
            'sparse_coo': 'paddle::dialect::SparseCooTensorType',
            'sparse_csr': 'paddle::dialect::SparseCsrTensorType',
        }
        ret = ''
        dispatch_kernel = None
        if op_info.kernel_map and 'dispatch' in op_info.kernel_map:
            dispatch_kernel = op_info.kernel_map['dispatch']
        if dispatch_kernel and len(dispatch_kernel.keys()) > 1:
            # Multiple kernels: build one guarded branch per kernel.
            api_inner_code = ''
            for kernel_name in dispatch_kernel.keys():
                dispatch_input_type = dispatch_kernel[kernel_name][0]
                input_name = op_info.input_name_list
                input_optional = op_info.input_optional_list
                cond_list = []
                for i, tag in enumerate(dispatch_input_type):
                    name = input_name[i]
                    optional = input_optional[i]
                    pir_type = dispatch_tag_to_pir_type.get(tag)
                    if pir_type is None:
                        continue
                    if optional == 'true':
                        # An absent optional input satisfies the condition.
                        cond_list.append(
                            f'(!{name} || {name}->type().isa<{pir_type}>())'
                        )
                    else:
                        cond_list.append(
                            f'{name}.type().isa<{pir_type}>()'
                        )
                ret_type = self._gen_ret_type(op_info)
                in_combine, in_combine_op_list = self._gen_in_combine(
                    op_info, is_mutable_attr, is_vector_mutable_attr
                )
                # Inplace APIs must dispatch to the inplace kernel variant.
                if op_name.endswith('_') and not kernel_name.endswith('_'):
                    kernel_name = kernel_name + '_'
                compute_op, op_inst_name = self._gen_compute_op(
                    op_info, kernel_name, in_combine_op_list, is_mutable_attr
                )
                if ret_type == 'void':
                    compute_op += f' (void){op_inst_name};'

                out_split, ret_list = self._gen_out_split_and_ret_list(
                    op_info, op_inst_name
                )
                if_inner_code = API_INNER_CODE_TEMPLATE.format(
                    amp_logic=self._gen_amp_logic(
                        op_info, op_name, is_mutable_attr
                    ),
                    type_promotion_logic=self._gen_type_promotion_logic(
                        op_info, op_name
                    ),
                    type_autocast_logic=self._gen_type_autocast_logic(
                        op_info, op_name
                    ),
                    check_data_type=self._gen_check_data_type(
                        op_info, kernel_name
                    ),
                    handle_optional_inputs=self._gen_handle_optional_inputs(
                        op_info
                    ),
                    in_combine=in_combine,
                    compute_op=compute_op,
                    handle_optional_outputs=self._gen_handle_optional_inplace_outputs(
                        op_info, kernel_name
                    ),
                    set_null_type=self._gen_set_null_type(op_info, kernel_name),
                    out_split=out_split,
                    set_stop_gradient=self._gen_set_stop_gradient(ret_list),
                    return_result=self._gen_return_result(ret_list),
                )
                # Indent the branch body one level inside the dispatch if.
                if_inner_code = if_inner_code.split('\n')
                if_inner_code = '\n'.join(
                    ['    ' + code for code in if_inner_code]
                )

                api_inner_code += OP_DISPATCH_TEMPLATE.format(
                    cond=' && '.join(cond_list), inner_code=if_inner_code
                )
            # Sparse ops use a suffixed name; applied after the loop so the
            # per-branch code above sees the plain name.
            if op_info.is_sparse_op:
                op_name += "sp_" if op_name[-1] == "_" else "_sp"
            api_inner_code += OP_DISPATCH_ERROR_TEMPLATE.format(op_name=op_name)
            ret = API_IMPL_TEMPLATE.format(
                ret_type=ret_type,
                api_name=op_name,
                args=self._gen_api_args(
                    op_info, False, is_mutable_attr, is_vector_mutable_attr
                ),
                inner_code=api_inner_code,
            )

        else:
            # Zero or one dispatch kernel: emit a single unconditional body.
            ret_type = self._gen_ret_type(op_info)
            in_combine, in_combine_op_list = self._gen_in_combine(
                op_info, is_mutable_attr, is_vector_mutable_attr
            )
            compute_op, op_inst_name = self._gen_compute_op(
                op_info, op_name, in_combine_op_list, is_mutable_attr
            )
            if ret_type == 'void':
                compute_op += f' (void){op_inst_name};'

            out_split, ret_list = self._gen_out_split_and_ret_list(
                op_info, op_inst_name
            )

            # With exactly one dispatch entry, use its kernel name; with
            # none, the kernel name is the op name itself.
            kernel_name = (
                next(iter(dispatch_kernel.keys()))
                if dispatch_kernel and len(dispatch_kernel.keys()) == 1
                else op_name
            )
            if op_name.endswith('_') and not kernel_name.endswith('_'):
                kernel_name = kernel_name + '_'
            api_inner_code = API_INNER_CODE_TEMPLATE.format(
                amp_logic=self._gen_amp_logic(
                    op_info, op_name, is_mutable_attr
                ),
                type_promotion_logic=self._gen_type_promotion_logic(
                    op_info, op_name
                ),
                type_autocast_logic=self._gen_type_autocast_logic(
                    op_info, op_name
                ),
                check_data_type=self._gen_check_data_type(op_info, kernel_name),
                handle_optional_inputs=self._gen_handle_optional_inputs(
                    op_info
                ),
                in_combine=in_combine,
                compute_op=compute_op,
                handle_optional_outputs=self._gen_handle_optional_inplace_outputs(
                    op_info, op_name
                ),
                set_null_type=self._gen_set_null_type(op_info, op_name),
                out_split=out_split,
                set_stop_gradient=self._gen_set_stop_gradient(ret_list),
                return_result=self._gen_return_result(ret_list),
            )
            if op_info.is_sparse_op:
                op_name += "sp_" if op_name[-1] == "_" else "_sp"
            ret = API_IMPL_TEMPLATE.format(
                ret_type=ret_type,
                api_name=op_name,
                args=self._gen_api_args(
                    op_info, False, is_mutable_attr, is_vector_mutable_attr
                ),
                inner_code=api_inner_code,
            )

        # Strip runs of trailing spaces (together with their newline) left
        # behind by empty template slots.
        ret = re.sub(r' +\n', "", ret)
        return ret

    def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path):
        """Render every op's API implementation and write the .cc file.

        For each op, the plain overload is always generated; when the op
        has mutable attributes an extra overload is generated, plus a
        vector overload when one of those attributes is an IntArray.
        """
        impl_str = ""
        for op_info in op_info_items:
            for op_name in op_info.op_phi_name:
                # NOTE: When infer_meta_func is None, the Build() function
                # generated in pd_op is wrong, so temporarily skip the
                # automatic generation of these APIs.
                skip = (
                    self._need_skip(op_info, op_name)
                    or op_name in PD_MANUAL_API_LIST
                )
                if skip:
                    continue
                impl_str += self._gen_one_impl(op_info, op_name, False, False)
                if op_info.mutable_attribute_name_list:
                    impl_str += self._gen_one_impl(
                        op_info, op_name, True, False
                    )
                    mutable_attr_kinds = {
                        attr_type[0]
                        for attr_type in op_info.mutable_attribute_type_list
                    }
                    if INTARRAY_ATTRIBUTE in mutable_attr_kinds:
                        impl_str += self._gen_one_impl(
                            op_info, op_name, True, True
                        )
        # Wrap the implementations in the namespace blocks, innermost last.
        body = impl_str
        for namespace in reversed(namespaces):
            body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
        with open(cpp_file_path, 'w') as f:
            f.write(CPP_FILE_TEMPLATE.format(body=body))

    def gen_h_and_cpp_file(
        self,
        op_yaml_files,
        op_compat_yaml_file,
        namespaces,
        h_file_path,
        cpp_file_path,
    ):
        """Parse the op yaml files and (re)generate the API .h/.cc files.

        Args:
            op_yaml_files: Paths of the op definition yaml files.
            op_compat_yaml_file: Path of the op compat yaml file.
            namespaces: Namespaces to wrap the generated code in.
            h_file_path: Output path for the generated header.
            cpp_file_path: Output path for the generated source.
        """
        # Remove stale outputs first so a failed run cannot leave old,
        # inconsistent generated files behind.
        if os.path.exists(h_file_path):
            os.remove(h_file_path)
        if os.path.exists(cpp_file_path):
            os.remove(cpp_file_path)
        op_info_items = self._parse_yaml(op_yaml_files, op_compat_yaml_file)
        self._gen_h_file(op_info_items, namespaces, h_file_path)
        self._gen_cpp_file(op_info_items, namespaces, cpp_file_path)
        # Formatting is best-effort: clang-format may be missing on the
        # build machine, and generation must still succeed without it.
        # Only environment/launch failures are swallowed; any other error
        # should surface instead of being silently ignored.
        try:
            subprocess.run(['clang-format', '-style=Google', '-i', h_file_path])
            subprocess.run(
                ['clang-format', '-style=Google', '-i', cpp_file_path]
            )
        except (OSError, subprocess.SubprocessError):
            pass


def ParseArguments(argv=None):
    """Parse command-line arguments for the dialect API generator.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            which makes argparse read ``sys.argv[1:]``, so existing
            callers are unaffected.

    Returns:
        argparse.Namespace with ``op_yaml_files``, ``op_compat_yaml_file``,
        ``namespaces``, ``api_def_h_file`` and ``api_def_cc_file``
        attributes (each a string, or None when not supplied).
    """
    parser = argparse.ArgumentParser(
        description='Generate Dialect API Files By Yaml'
    )
    parser.add_argument('--op_yaml_files', type=str)
    parser.add_argument('--op_compat_yaml_file', type=str)
    parser.add_argument('--namespaces', type=str)
    parser.add_argument('--api_def_h_file', type=str)
    parser.add_argument('--api_def_cc_file', type=str)
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: read CLI options and run the code generator.
    args = ParseArguments()

    # --op_yaml_files is a comma-separated list of yaml paths.
    op_yaml_files = args.op_yaml_files.split(",")
    op_compat_yaml_file = args.op_compat_yaml_file
    # NOTE(review): generation only runs when --namespaces is supplied;
    # if it is omitted the script parses its arguments and exits without
    # writing any files — confirm this is intentional.
    if args.namespaces is not None:
        namespaces = args.namespaces.split(",")
        api_def_h_file = args.api_def_h_file
        api_def_cc_file = args.api_def_cc_file

        code_gen = CodeGen()
        code_gen.gen_h_and_cpp_file(
            op_yaml_files,
            op_compat_yaml_file,
            namespaces,
            api_def_h_file,
            api_def_cc_file,
        )